/*
 * SPDX-FileCopyrightText: Copyright (c) 1993-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 * SPDX-License-Identifier: Apache-2.0
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "tensorrt_llm/common/cudaUtils.h"
#include "tensorrt_llm/common/opUtils.h"
#include "tensorrt_llm/kernels/speculativeDecoding/draftTokenTreeKernels.h"
#include "tensorrt_llm/kernels/speculativeDecoding/mtpKernels.h"
#include "tensorrt_llm/runtime/torchUtils.h"

namespace th = torch;
namespace tl = tensorrt_llm;
namespace tk = tensorrt_llm::kernels;

namespace torch_ext
{

////////////////////////////////////////////////////////////////////////////////////////////////////////////
// Torch op: packs the arguments into tk::MTPPrepareDrafterInputsParam and launches
// tk::invokeMTPPrepareDrafterInputs on the current CUDA stream of hiddenStates' device.
//
// Args:
// inputIds, seqLens, acceptedTokens, numAcceptedTokens: Tensors whose storage is handed to the kernel as int*
// mtpPastHiddenStatesPtrs: Tensor whose storage is handed to the kernel as void** (an array of pointers)
// mtpPastTokensPtrs: Tensor whose storage is handed to the kernel as int** (an array of pointers)
// hiddenStates: Tensor, float16 / float32 / bfloat16; its dtype selects the kernel instantiation
// returnInputIds: Tensor, handed to the kernel as int*; returned as the first tuple element
// returnHiddenStates: Tensor, handed to the kernel as an untyped buffer; returned as the second tuple element
// numMTPModules, batchSize, numContextRequest, hiddenSize: forwarded to the kernel parameters unchanged
//
// Checked here:
//     inputIds.size(0) == hiddenStates.size(0)
//     seqLens.size(0) == batchSize
//
// NOTE(review): the element types of the int / pointer tensors and the dtype of returnHiddenStates are not
// validated here; presumably int32, pointer-sized integers and hiddenStates' dtype respectively - confirm
// against the callers.
//
// Returns (returnInputIds, returnHiddenStates), i.e. the very tensors that were passed in.
// Throws std::invalid_argument when hiddenStates has an unsupported dtype.
std::tuple<th::Tensor, th::Tensor> mtp_prepare_drafter_inputs_op(th::Tensor& inputIds, th::Tensor& seqLens,
    th::Tensor& mtpPastHiddenStatesPtrs, th::Tensor& mtpPastTokensPtrs, th::Tensor& hiddenStates,
    th::Tensor& acceptedTokens, th::Tensor& numAcceptedTokens, th::Tensor& returnInputIds,
    th::Tensor& returnHiddenStates, int64_t numMTPModules, int64_t batchSize, int64_t numContextRequest,
    int64_t hiddenSize)
{
    auto const dataType = hiddenStates.scalar_type();

    // Shape checks. Tensor::size(dim) range-checks the dimension index, unlike indexing the array returned by
    // sizes(), so a 0-dim tensor is reported as an error instead of reading past the end of the shape array.
    TLLM_CHECK(inputIds.size(0) == hiddenStates.size(0));
    TLLM_CHECK(seqLens.size(0) == batchSize);

    auto stream = at::cuda::getCurrentCUDAStream(hiddenStates.get_device());

    // Fill params
    tk::MTPPrepareDrafterInputsParam params;
    params.numMTPModules = numMTPModules;
    params.batchSize = batchSize;
    params.numContextRequest = numContextRequest;
    params.hiddenSize = hiddenSize;
    params.inputIds = static_cast<int*>(inputIds.data_ptr());
    params.seqLens = static_cast<int*>(seqLens.data_ptr());
    params.mtpPastHiddenStatesPtrs = static_cast<void**>(mtpPastHiddenStatesPtrs.data_ptr());
    params.mtpPastTokensPtrs = static_cast<int**>(mtpPastTokensPtrs.data_ptr());
    params.hiddenStates = hiddenStates.data_ptr();
    params.acceptedTokens = static_cast<int*>(acceptedTokens.data_ptr());
    params.numAcceptedTokens = static_cast<int*>(numAcceptedTokens.data_ptr());
    params.returnInputIds = static_cast<int*>(returnInputIds.data_ptr());
    params.returnHiddenStates = returnHiddenStates.data_ptr();

    // Dispatch on the element type of hiddenStates.
    switch (dataType)
    {
    case torch::kFloat16: tk::invokeMTPPrepareDrafterInputs<half>(params, stream); break;
    case torch::kFloat32: tk::invokeMTPPrepareDrafterInputs<float>(params, stream); break;
    case torch::kBFloat16: tk::invokeMTPPrepareDrafterInputs<__nv_bfloat16>(params, stream); break;
    default: throw std::invalid_argument("Invalid dtype, only supports float16, float32, and bfloat16");
    }

    return std::make_tuple(returnInputIds, returnHiddenStates);
}

////////////////////////////////////////////////////////////////////////////////////////////////////////////
// Torch op: allocates the acceptance outputs, packs everything into tk::MTPSampleAndAcceptDraftTokensParam and
// launches tk::invokeMTPSampleAndAcceptDraftTokens on the current CUDA stream of logits' device.
//
// Args:
// logits: Tensor, float16 / float32 / bfloat16; its dtype selects the kernel instantiation
//     shape: [numContextRequest + numGenerationRequest * (numMTPModules + 1), vocabSize]
// draftTokens: Tensor, handed to the kernel as int*
//     shape: [numGenerationRequest * numMTPModules]
// targetTokens: Tensor, handed to the kernel as int*
// numMTPModules, batchSize, numContextRequest, vocabSize: forwarded to the kernel parameters unchanged
// where numGenerationRequest = batchSize - numContextRequest.
//
// NOTE(review): the dtype of draftTokens / targetTokens and the size of targetTokens are not validated here;
// presumably both are int32 - confirm against the callers.
//
// Returns (acceptedTokens, numAcceptedTokens): freshly allocated int32 tensors on logits' device, initialised to
// one before the kernel runs, with shapes [batchSize, numMTPModules + 1] and [batchSize].
// Throws std::invalid_argument when logits has an unsupported dtype.
std::tuple<th::Tensor, th::Tensor> mtp_sampling_and_accepted_draft_tokens_op(th::Tensor& logits,
    th::Tensor& draftTokens, th::Tensor& targetTokens, int64_t numMTPModules, int64_t batchSize,
    int64_t numContextRequest, int64_t vocabSize)
{
    // Kept in 64 bits so that the difference of two int64 arguments is not silently narrowed before it is used
    // in the shape checks below.
    int64_t const numGenerationRequest = batchSize - numContextRequest;
    auto const dataType = logits.scalar_type();

    // Shape checks
    TORCH_CHECK(logits.dim() == 2, "logits must be a 2D Tensor");
    TLLM_CHECK(logits.size(0) == (numContextRequest + numGenerationRequest * (numMTPModules + 1)));
    // The kernel only receives a raw logits pointer plus vocabSize, so a mismatch with the real row width
    // cannot be detected once the pointer has been handed over.
    TLLM_CHECK(logits.size(1) == vocabSize);

    TORCH_CHECK(draftTokens.dim() == 1);
    TLLM_CHECK(draftTokens.size(0) == (numGenerationRequest * numMTPModules));

    auto stream = at::cuda::getCurrentCUDAStream(logits.get_device());

    // Outputs live on the same device as logits.
    auto const int32OnLogitsDevice = at::TensorOptions().dtype(torch::kInt32).device(logits.device());
    auto acceptedTokens = torch::ones({batchSize, numMTPModules + 1}, int32OnLogitsDevice);
    auto numAcceptedTokens = torch::ones({batchSize}, int32OnLogitsDevice);

    // Fill params
    tk::MTPSampleAndAcceptDraftTokensParam params;
    params.numMTPModules = numMTPModules;
    params.batchSize = batchSize;
    params.numContextRequest = numContextRequest;
    params.vocabSize = vocabSize;
    params.draftTokens = static_cast<int*>(draftTokens.data_ptr());
    params.targetTokens = static_cast<int*>(targetTokens.data_ptr());
    params.acceptedTokens = static_cast<int*>(acceptedTokens.data_ptr());
    params.numAcceptedTokens = static_cast<int*>(numAcceptedTokens.data_ptr());
    params.logits = logits.data_ptr();

    // Dispatch on the element type of logits.
    switch (dataType)
    {
    case torch::kFloat16: tk::invokeMTPSampleAndAcceptDraftTokens<half>(params, stream); break;
    case torch::kFloat32: tk::invokeMTPSampleAndAcceptDraftTokens<float>(params, stream); break;
    case torch::kBFloat16: tk::invokeMTPSampleAndAcceptDraftTokens<__nv_bfloat16>(params, stream); break;
    default: throw std::invalid_argument("Invalid dtype, only supports float16, float32, and bfloat16");
    }

    return std::make_tuple(acceptedTokens, numAcceptedTokens);
}

////////////////////////////////////////////////////////////////////////////////////////////////////////////
// Torch op: packs the arguments into tk::MTPUpdateHiddenStatesParam and launches
// tk::invokeMTPUpdateHiddenStates on the current CUDA stream of targetModelHiddenStates' device.
//
// Args:
// inputIds, seqLens, numAcceptedTokens: Tensors whose storage is handed to the kernel as int*
// targetModelHiddenStates: Tensor, float16 / float32 / bfloat16; its dtype selects the kernel instantiation
// mtpPastHiddenStatesPtrs: Tensor whose storage is handed to the kernel as void** (an array of pointers)
// mtpPastTokensPtrs: Tensor whose storage is handed to the kernel as int** (an array of pointers)
// numMTPModules, batchSize, numContextRequest, hiddenSize: forwarded to the kernel parameters unchanged
//
// Checked here:
//     inputIds.size(0) == targetModelHiddenStates.size(0)
//     numAcceptedTokens.size(0) == batchSize
//
// NOTE(review): the element types of the int / pointer tensors are not validated here; presumably int32 and
// pointer-sized integers - confirm against the callers.
//
// Returns (mtpPastHiddenStatesPtrs, mtpPastTokensPtrs), i.e. the very tensors that were passed in.
// Throws std::invalid_argument when targetModelHiddenStates has an unsupported dtype.
std::tuple<th::Tensor, th::Tensor> mtp_update_hidden_states_op(th::Tensor& inputIds, th::Tensor& seqLens,
    th::Tensor& targetModelHiddenStates, th::Tensor& mtpPastHiddenStatesPtrs, th::Tensor& mtpPastTokensPtrs,
    th::Tensor& numAcceptedTokens, int64_t numMTPModules, int64_t batchSize, int64_t numContextRequest,
    int64_t hiddenSize)
{
    auto const dataType = targetModelHiddenStates.scalar_type();

    // Shape checks. Tensor::size(dim) range-checks the dimension index, unlike indexing the array returned by
    // sizes(), so a 0-dim tensor is reported as an error instead of reading past the end of the shape array.
    TLLM_CHECK(inputIds.size(0) == targetModelHiddenStates.size(0));
    TLLM_CHECK(numAcceptedTokens.size(0) == batchSize);

    auto stream = at::cuda::getCurrentCUDAStream(targetModelHiddenStates.get_device());

    // Fill params
    tk::MTPUpdateHiddenStatesParam params;
    params.numMTPModules = numMTPModules;
    params.batchSize = batchSize;
    params.numContextRequest = numContextRequest;
    params.hiddenSize = hiddenSize;
    params.inputIds = static_cast<int*>(inputIds.data_ptr());
    params.seqLens = static_cast<int*>(seqLens.data_ptr());
    params.targetModelHiddenStates = targetModelHiddenStates.data_ptr();
    params.mtpPastHiddenStatesPtrs = static_cast<void**>(mtpPastHiddenStatesPtrs.data_ptr());
    params.mtpPastTokensPtrs = static_cast<int**>(mtpPastTokensPtrs.data_ptr());
    params.numAcceptedTokens = static_cast<int*>(numAcceptedTokens.data_ptr());

    // Dispatch on the element type of targetModelHiddenStates.
    switch (dataType)
    {
    case torch::kFloat16: tk::invokeMTPUpdateHiddenStates<half>(params, stream); break;
    case torch::kFloat32: tk::invokeMTPUpdateHiddenStates<float>(params, stream); break;
    case torch::kBFloat16: tk::invokeMTPUpdateHiddenStates<__nv_bfloat16>(params, stream); break;
    default: throw std::invalid_argument("Invalid dtype, only supports float16, float32, and bfloat16");
    }

    return std::make_tuple(mtpPastHiddenStatesPtrs, mtpPastTokensPtrs);
}

////////////////////////////////////////////////////////////////////////////////////////////////////////////
// Torch op: packs the arguments into tk::MTPRelaxedAcceptanceParam and launches
// tk::invokeMTPRelaxedAcceptance on the current CUDA stream of numAcceptedTokens' device.
//
// Args:
// reqSlotIds: Tensor, handed to the kernel as int*
// topKValue: Tensor, float16 / float32 / bfloat16; its dtype selects the kernel instantiation
//     leading shape: [numGenerationRequest, numMTPModules + 1, relaxedTopK]
// topKIndices: Tensor, handed to the kernel as int64_t*
// draftTokens: Tensor, handed to the kernel as int*
//     dim 0: numGenerationRequest
// mtpRelaxedDelta: Tensor, handed to the kernel as float*
// numAcceptedTokens: Tensor, handed to the kernel as int*
//     dim 0: batchSize
// acceptedTokens: Tensor, handed to the kernel as int*
// numMTPModules, batchSize, numContextRequest, relaxedTopK, beginThinkingTokens, endThinkingTokens:
//     forwarded to the kernel parameters unchanged
// relaxedDelta: converted to float before being forwarded
// where numGenerationRequest = batchSize - numContextRequest.
//
// NOTE(review): the element types of the integer / float tensors and the shape of topKIndices are not
// validated here; presumably they match the pointer types listed above - confirm against the callers.
//
// Returns (acceptedTokens, numAcceptedTokens), i.e. the very tensors that were passed in.
// Throws std::invalid_argument when topKValue has an unsupported dtype.
std::tuple<th::Tensor, th::Tensor> mtp_relaxed_acceptance_op(th::Tensor& reqSlotIds, th::Tensor& topKValue,
    th::Tensor& topKIndices, th::Tensor& draftTokens, th::Tensor& mtpRelaxedDelta, th::Tensor& numAcceptedTokens,
    th::Tensor& acceptedTokens, int64_t const numMTPModules, int64_t const batchSize, int64_t const numContextRequest,
    int64_t const relaxedTopK, double const relaxedDelta, int64_t const beginThinkingTokens,
    int64_t const endThinkingTokens)
{
    auto const dataType = topKValue.scalar_type();
    auto const numGenerationRequest = batchSize - numContextRequest;

    // Shape checks. Tensor::size(dim) range-checks the dimension index, so a topKValue with fewer than three
    // dimensions is reported as an error; indexing the array returned by sizes() would read out of bounds.
    TLLM_CHECK(topKValue.size(0) == numGenerationRequest);
    TLLM_CHECK(topKValue.size(1) == numMTPModules + 1);
    TLLM_CHECK(topKValue.size(2) == relaxedTopK);
    TLLM_CHECK(draftTokens.size(0) == numGenerationRequest);
    TLLM_CHECK(numAcceptedTokens.size(0) == batchSize);

    auto stream = at::cuda::getCurrentCUDAStream(numAcceptedTokens.get_device());

    // Fill params
    tk::MTPRelaxedAcceptanceParam params;
    params.numMTPModules = numMTPModules;
    params.batchSize = batchSize;
    params.numContextRequest = numContextRequest;
    params.relaxedTopK = relaxedTopK;
    params.relaxedDelta = static_cast<float>(relaxedDelta);
    params.beginThinkingTokens = beginThinkingTokens;
    params.endThinkingTokens = endThinkingTokens;
    params.reqSlotIds = static_cast<int*>(reqSlotIds.data_ptr());
    params.topKValue = topKValue.data_ptr();
    params.topKIndices = static_cast<int64_t*>(topKIndices.data_ptr());
    params.draftTokens = static_cast<int*>(draftTokens.data_ptr());
    params.mtpRelaxedDelta = static_cast<float*>(mtpRelaxedDelta.data_ptr());
    params.numAcceptedTokens = static_cast<int*>(numAcceptedTokens.data_ptr());
    params.acceptedTokens = static_cast<int*>(acceptedTokens.data_ptr());

    // Dispatch on the element type of topKValue.
    switch (dataType)
    {
    case torch::kFloat16: tk::invokeMTPRelaxedAcceptance<half>(params, stream); break;
    case torch::kFloat32: tk::invokeMTPRelaxedAcceptance<float>(params, stream); break;
    case torch::kBFloat16: tk::invokeMTPRelaxedAcceptance<__nv_bfloat16>(params, stream); break;
    default: throw std::invalid_argument("Invalid dtype, only supports float16, float32, and bfloat16");
    }

    return std::make_tuple(acceptedTokens, numAcceptedTokens);
}

////////////////////////////////////////////////////////////////////////////////////////////////////////////
// Torch op: validates its arguments, fills tk::ExtractRealDraftTokensParam and launches
// tk::invokeExtractRealDraftTokens on the current CUDA stream of newDraftTokens' device.
//
// Tensor arguments (dtype and shape are enforced by the checks in the body unless noted otherwise):
// newDraftTokens: int64, the new draft tokens; only this layer's tokens are extracted and written back to
//     draftTokensBuffer
//     shape: [batchSize, 1, maxTopK] when curDraftIdx == 0, else [batchSize, maxTotalDraftTokens + 1, maxTopK]
// draftTokensBuffer: int64, the buffer that stores the real draft tokens
//     shape: [maxBatchSize, maxTotalDraftTokens + 1] (only dim 1 is checked)
// tokensGatherIdxForDrafterModel: int32, indices of the draft tokens that are expanded in this layer
//     shape: [numTokensExpandThisLayer], must not be empty
// topKList: int32, top k value for each expandable token
//     shape: [numTokensExpandThisLayer]
// draftTokensIndicesCumsum: int32, cumulative sum of the write back indices for each draft layer
//     shape: [maxDraftLen + 1]
//
// Scalar arguments curDraftIdx, batchSize, maxDraftLen, maxTotalDraftTokens and maxTopK are forwarded to the
// kernel parameters unchanged.
void extract_real_draft_tokens_op(th::Tensor newDraftTokens, th::Tensor draftTokensBuffer,
    th::Tensor tokensGatherIdxForDrafterModel, th::Tensor topKList, th::Tensor draftTokensIndicesCumsum,
    int64_t curDraftIdx, int64_t batchSize, int64_t maxDraftLen, int64_t maxTotalDraftTokens, int64_t maxTopK)
{
    // --- dtype validation: index-like tensors are int32, token tensors are int64 ---
    TLLM_CHECK(tokensGatherIdxForDrafterModel.scalar_type() == torch::kInt32);
    TLLM_CHECK(topKList.scalar_type() == torch::kInt32);
    TLLM_CHECK(draftTokensIndicesCumsum.scalar_type() == torch::kInt32);
    TLLM_CHECK(newDraftTokens.scalar_type() == torch::kInt64);
    TLLM_CHECK(draftTokensBuffer.scalar_type() == torch::kInt64);

    // --- shape validation ---
    // One top-k entry per token that gets expanded, and at least one such token.
    auto const numTokensExpandThisLayer = tokensGatherIdxForDrafterModel.size(0);
    TLLM_CHECK(numTokensExpandThisLayer > 0);
    TLLM_CHECK(topKList.size(0) == numTokensExpandThisLayer);

    TLLM_CHECK(draftTokensIndicesCumsum.size(0) == maxDraftLen + 1);

    // The middle dimension of newDraftTokens depends on whether this is the first draft layer;
    // the last dimension is maxTopK in both cases.
    TLLM_CHECK(newDraftTokens.size(0) == batchSize);
    if (curDraftIdx == 0)
    {
        TLLM_CHECK(newDraftTokens.size(1) == 1);
    }
    else
    {
        TLLM_CHECK(newDraftTokens.size(1) == maxTotalDraftTokens + 1);
    }
    TLLM_CHECK(newDraftTokens.size(2) == maxTopK);

    TLLM_CHECK(draftTokensBuffer.size(1) == maxTotalDraftTokens + 1);

    auto stream = at::cuda::getCurrentCUDAStream(newDraftTokens.get_device());

    // --- kernel parameters: scalars first, then the raw data pointers ---
    tk::ExtractRealDraftTokensParam params;
    params.curDraftIdx = curDraftIdx;
    params.batchSize = batchSize;
    params.maxDraftLen = maxDraftLen;
    params.maxTotalDraftTokens = maxTotalDraftTokens;
    params.maxTopK = maxTopK;
    params.numTokensExpandThisLayer = numTokensExpandThisLayer;
    params.tokensGatherIdxForDrafterModel = static_cast<int32_t*>(tokensGatherIdxForDrafterModel.data_ptr());
    params.topKList = static_cast<int32_t*>(topKList.data_ptr());
    params.draftTokensIndicesCumsum = static_cast<int32_t*>(draftTokensIndicesCumsum.data_ptr());
    params.newDraftTokens = static_cast<int64_t*>(newDraftTokens.data_ptr());
    params.draftTokensBuffer = static_cast<int64_t*>(draftTokensBuffer.data_ptr());

    tk::invokeExtractRealDraftTokens(params, stream);
}

} // end namespace torch_ext

// Declares the schema of mtp_prepare_drafter_inputs_op in the trtllm operator library.
// The argument list and the (Tensor, Tensor) return must stay in sync with the signature of
// torch_ext::mtp_prepare_drafter_inputs_op.
TORCH_LIBRARY_FRAGMENT(trtllm, m)
{
    m.def(
        "mtp_prepare_drafter_inputs_op(Tensor inputIds, Tensor seqLens, Tensor "
        "mtpPastHiddenStatesPtrs, Tensor mtpPastTokensPtrs, Tensor hiddenStates, "
        "Tensor acceptedTokens, Tensor numAcceptedTokens, Tensor returnInputIds, Tensor returnHiddenStates, "
        "int numMTPModules, int batchSize, int numContextRequest,"
        "int hiddenSize) -> (Tensor, Tensor)");
}

// Registers torch_ext::mtp_prepare_drafter_inputs_op as the CUDA implementation of the op declared above.
TORCH_LIBRARY_IMPL(trtllm, CUDA, m)
{
    m.impl("mtp_prepare_drafter_inputs_op", &torch_ext::mtp_prepare_drafter_inputs_op);
}

////////////////////////////////////////////////////////////////////////////////////////////////////////////

// Declares the schema of mtp_sampling_and_accepted_draft_tokens_op in the trtllm operator library.
// The argument list and the (Tensor, Tensor) return must stay in sync with the signature of
// torch_ext::mtp_sampling_and_accepted_draft_tokens_op.
TORCH_LIBRARY_FRAGMENT(trtllm, m)
{
    m.def(
        "mtp_sampling_and_accepted_draft_tokens_op(Tensor logits, Tensor draftTokens, Tensor "
        "targetTokens, int numMTPModules, "
        "int batchSize, int numContextRequest, int vocabSize) -> (Tensor, Tensor)");
}

// Registers torch_ext::mtp_sampling_and_accepted_draft_tokens_op as the CUDA implementation of the op
// declared above.
TORCH_LIBRARY_IMPL(trtllm, CUDA, m)
{
    m.impl("mtp_sampling_and_accepted_draft_tokens_op", &torch_ext::mtp_sampling_and_accepted_draft_tokens_op);
}

////////////////////////////////////////////////////////////////////////////////////////////////////////////

// Declares the schema of mtp_update_hidden_states_op in the trtllm operator library.
// The argument list and the (Tensor, Tensor) return must stay in sync with the signature of
// torch_ext::mtp_update_hidden_states_op.
TORCH_LIBRARY_FRAGMENT(trtllm, m)
{
    m.def(
        "mtp_update_hidden_states_op(Tensor inputIds, Tensor seqLens, Tensor targetModelHiddenStates, "
        "Tensor mtpPastHiddenStatesPtrs, Tensor mtpPastTokensPtrs, Tensor numAcceptedTokens, "
        "int numMTPModules, int batchSize, int numContextRequest, int hiddenSize) -> (Tensor, Tensor)");
}

// Registers torch_ext::mtp_update_hidden_states_op as the CUDA implementation of the op declared above.
TORCH_LIBRARY_IMPL(trtllm, CUDA, m)
{
    m.impl("mtp_update_hidden_states_op", &torch_ext::mtp_update_hidden_states_op);
}

////////////////////////////////////////////////////////////////////////////////////////////////////////////

// Declares the schema of mtp_relaxed_acceptance_op in the trtllm operator library.
// The argument list and the (Tensor, Tensor) return must stay in sync with the signature of
// torch_ext::mtp_relaxed_acceptance_op; the schema type `float` corresponds to the C++ `double` parameter
// relaxedDelta.
TORCH_LIBRARY_FRAGMENT(trtllm, m)
{
    m.def(
        "mtp_relaxed_acceptance_op(Tensor reqSlotIds, Tensor topKValue, Tensor topKIndices, Tensor draftTokens, "
        "Tensor mtpRelaxedDelta, Tensor numAcceptedTokens, Tensor acceptedTokens, "
        "int numMTPModules, int batchSize, int numContextRequest, int relaxedTopK, "
        "float relaxedDelta, int beginThinkingTokens, int endThinkingTokens) -> (Tensor, Tensor)");
}

// Registers torch_ext::mtp_relaxed_acceptance_op as the CUDA implementation of the op declared above.
TORCH_LIBRARY_IMPL(trtllm, CUDA, m)
{
    m.impl("mtp_relaxed_acceptance_op", &torch_ext::mtp_relaxed_acceptance_op);
}

////////////////////////////////////////////////////////////////////////////////////////////////////////////

// Declares the schema of extract_real_draft_tokens_op in the trtllm operator library.
// The argument list and the empty return `()` must stay in sync with the signature of
// torch_ext::extract_real_draft_tokens_op, which returns void.
TORCH_LIBRARY_FRAGMENT(trtllm, m)
{
    m.def(
        "extract_real_draft_tokens_op(Tensor newDraftTokens, Tensor draftTokensBuffer, "
        "Tensor tokensGatherIdxForDrafterModel, Tensor topKList, Tensor draftTokensIndicesCumsum, "
        "int curDraftIdx, int batchSize, int maxDraftLen, int maxTotalDraftTokens, int maxTopK) -> ()");
}

// Registers torch_ext::extract_real_draft_tokens_op as the CUDA implementation of the op declared above.
TORCH_LIBRARY_IMPL(trtllm, CUDA, m)
{
    m.impl("extract_real_draft_tokens_op", &torch_ext::extract_real_draft_tokens_op);
}
